
The Simpsons is one of the longest-running and most beloved animated series, featuring a vast and diverse array of characters. Predicting these characters from images is not only a fun and engaging task but also holds practical applications in media analysis, entertainment, and fan engagement.
Moreover, The Simpsons' style tends to use flat colors and simplistic character compositions, unlike realistic images which incorporate shadows, angles, profiles, and shapes. Given these differences, we are interested in determining whether training a model on The Simpsons' data would be more challenging or if achieving good performance would be relatively easier.
Size of the data set (data set link on Kaggle):
Referenced code:
References:
The pre-trained models used:
import os
import tensorflow as tf
import keras
import matplotlib.image as mpimg
import numpy as np
import random
import seaborn as sns
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dropout, Flatten, Dense, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing import image
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import confusion_matrix, classification_report
Split the data into train, test and validation sets
# Split the raw Kaggle image folders into train/val/test (80/10/10) with a
# fixed seed so the split is reproducible across runs.
import splitfolders

splitfolders.ratio(
    "./input/the-simpsons-characters-dataset/",
    output="./data",
    seed=66,
    ratio=(0.8, 0.1, 0.1),
)
Copying files: 0 files [00:00, ? files/s]Copying files: 20933 files [00:26, 778.39 files/s]
# Directory layout produced by the splitfolders step above.
train_dir = "./data/train"
val_dir = "./data/val"
test_dir = "./data/test"
# Resolution fed to the generators and models: (height, width, channels).
image_size = (180, 180, 3)
# Upper bound on training epochs (early stopping usually ends sooner).
epochs = 50
batch_size = 100
# Class names are simply the sub-folder names under the training directory.
labels = os.listdir(train_dir)
The ImageDataGenerator class randomly augments each image as it is loaded, without memory overload — this is known as on-the-fly data augmentation. Each epoch it serves TotalNumberOfTrainingImages/BatchSize = 16727/100 ≈ 168 batches of freshly augmented images, expanding the variety of training samples without increasing storage demands.
# Training generator: rescale to [0, 1] plus random geometric jitter
# (on-the-fly augmentation — no extra images are written to disk).
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
# Validation/test generator: rescaling only, no augmentation.
test_datagen = ImageDataGenerator(rescale=1.0 / 255)
def _directory_flow(datagen, directory, shuffle=True):
    """Build a categorical RGB batch iterator over *directory*."""
    return datagen.flow_from_directory(
        directory,
        target_size=image_size[:2],
        batch_size=batch_size,
        class_mode="categorical",
        color_mode="rgb",
        shuffle=shuffle,
    )


training_set = _directory_flow(train_datagen, train_dir)
val_set = _directory_flow(test_datagen, val_dir)
# Keep the test set in directory order so predictions line up with filenames.
test_set = _directory_flow(test_datagen, test_dir, shuffle=False)
Found 16727 images belonging to 42 classes. Found 2073 images belonging to 42 classes. Found 2133 images belonging to 42 classes.
# Preview a single training image.
img_path = "./data/train/homer_simpson/pic_0028.jpg"
# FIX: load_img's target_size is a (height, width) 2-tuple; the original
# passed (1080, 1080, 3), whose channel entry load_img silently ignores.
img = image.load_img(img_path, target_size=(1080, 1080))
plt.imshow(img)
plt.show()
# Class names in generator index order (index i -> classes[i]).
classes = list(training_set.class_indices.keys())
# NOTE: the original re-assigned train_dir to the identical "./data/train"
# here; that redundant assignment has been dropped.
def plot_image(size):
    """Show one randomly chosen training image per class on a size x size grid.

    `size` must satisfy size * size >= len(classes) so every class gets a cell.
    """
    plt.figure(figsize=(15, 18))
    for idx, cls in enumerate(classes):
        folder = f"{train_dir}/{cls}"
        # Only consider plain files (skip any stray sub-directories).
        candidates = [
            name
            for name in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, name))
        ]
        chosen = random.choice(candidates)
        plt.subplot(size, size, idx + 1)
        plt.imshow(mpimg.imread(f"{folder}/{chosen}"))
        plt.axis("off")
        # Humanize the folder name, e.g. "homer_simpson" -> "homer simpson".
        plt.title(' '.join(cls.split("_")))
    plt.tight_layout()
    plt.show()


plot_image(7)
def plot_size(len_l):
    """Draw a horizontal bar chart of per-class sample counts.

    Parameters
    ----------
    len_l : list[int]
        Image count per class, aligned with the module-level `classes` list.
    """
    plt.figure(figsize=(15, 15))
    ax = sns.barplot(x=len_l, y=classes, orient="h", color="navy")
    # BUG FIX: np.arange(0, 2000, 2000) produces only [0], erasing every
    # x-axis tick; step by 200 so the axis is actually labelled up to 2000.
    ax.set_xticks(np.arange(0, 2001, 200))
    ax.set_xlabel("Number of Images")
    ax.set_ylabel("Classes")
    ax.set_title("Number of samples for each class", fontsize=20)
    # Annotate each bar with its exact count at the bar tip.
    for i, p in enumerate(ax.patches):
        ax.text(
            p.get_width(),
            p.get_y() + p.get_height() / 2.0,
            "{}".format(len_l[i]),
            va="center",
            fontsize=15,
        )
# Count the training images available for every class and plot the
# (heavily imbalanced) distribution.
len_list = [len(os.listdir(f"{train_dir}/{class_}")) for class_ in classes]
plot_size(len_list)
def augment_image(image):
    """Return *image* with random horizontal flip, brightness and contrast jitter."""
    augmented = tf.image.random_flip_left_right(image)
    augmented = tf.image.random_brightness(augmented, max_delta=0.2)
    augmented = tf.image.random_contrast(augmented, lower=0.5, upper=1.5)
    return augmented
# Oversample under-represented classes: for every class with fewer than 100
# training images, write nine augmented copies of each existing image back
# into its folder. Fixes over the original:
#   * the inner loop variable no longer shadows the outer enumerate index `i`;
#   * the temporary no longer clobbers the imported `image` module
#     (tensorflow.keras.preprocessing.image);
#   * the dead `mpimg.imread` of the same file is removed;
#   * each source file is read and decoded once, not once per copy.
for class_idx, class_ in enumerate(classes):
    if len_list[class_idx] >= 100:
        continue
    folder_path = f"{train_dir}/{class_}"
    seq = len_list[class_idx]  # continue numbering after the existing images
    for file_ in os.listdir(folder_path):
        raw = tf.io.read_file(f"{folder_path}/{file_}")
        decoded = tf.image.decode_jpeg(raw, channels=3)
        for _ in range(9):
            augmented = augment_image(decoded)
            out = Image.fromarray(augmented.numpy())
            # NOTE(review): copies are saved as .png although the sources are
            # JPEGs; flow_from_directory reads both, so the mix is harmless.
            out.save(os.path.join(folder_path, f"pic_{str(seq).zfill(4)}.png"))
            seq += 1
# Recount the per-class totals after oversampling and re-plot the
# (now more balanced) distribution.
len_list_ = [len(os.listdir(f"{train_dir}/{class_}")) for class_ in classes]
plot_size(len_list_)
# Rebuild the training iterator so it picks up the augmented files that were
# just written to disk (the original iterator indexed the directory earlier).
training_set = train_datagen.flow_from_directory(
    train_dir,
    target_size=image_size[:2],
    class_mode="categorical",
    color_mode="rgb",
    batch_size=batch_size,
)
Found 24395 images belonging to 42 classes.
The following is referenced from Simpsons Image Classification -CNN | Val_acc=93%
# Custom CNN baseline: three conv blocks (32 -> 64 -> 256 filters), each
# block = same-padded conv + valid conv + 2x2 max-pool + 20% dropout,
# followed by a wide dense head.
model_cus = Sequential()
for block_no, filters in enumerate((32, 64, 256)):
    if block_no == 0:
        # First layer carries the input shape.
        model_cus.add(
            Conv2D(filters, (3, 3), padding="same", input_shape=image_size, activation="relu")
        )
    else:
        model_cus.add(Conv2D(filters, (3, 3), padding="same", activation="relu"))
    model_cus.add(Conv2D(filters, (3, 3), activation="relu"))
    model_cus.add(MaxPooling2D(pool_size=(2, 2)))
    model_cus.add(Dropout(0.2))
model_cus.add(Flatten())
model_cus.add(Dense(1024, activation="relu"))
model_cus.add(Dropout(0.5))
# One output unit per character class.
model_cus.add(Dense(len(labels), activation="softmax"))
model_cus.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model_cus.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 180, 180, 32) 896
conv2d_1 (Conv2D) (None, 178, 178, 32) 9248
max_pooling2d (MaxPooling2D (None, 89, 89, 32) 0
)
dropout (Dropout) (None, 89, 89, 32) 0
conv2d_2 (Conv2D) (None, 89, 89, 64) 18496
conv2d_3 (Conv2D) (None, 87, 87, 64) 36928
max_pooling2d_1 (MaxPooling (None, 43, 43, 64) 0
2D)
dropout_1 (Dropout) (None, 43, 43, 64) 0
conv2d_4 (Conv2D) (None, 43, 43, 256) 147712
conv2d_5 (Conv2D) (None, 41, 41, 256) 590080
max_pooling2d_2 (MaxPooling (None, 20, 20, 256) 0
2D)
dropout_2 (Dropout) (None, 20, 20, 256) 0
flatten (Flatten) (None, 102400) 0
dense (Dense) (None, 1024) 104858624
dropout_3 (Dropout) (None, 1024) 0
dense_1 (Dense) (None, 42) 43050
=================================================================
Total params: 105,705,034
Trainable params: 105,705,034
Non-trainable params: 0
_________________________________________________________________
# Persist the full model whenever validation accuracy improves, and stop
# training after five epochs without improvement.
checkpoint_filepath_cus = "./models/cus.h5"
model_cus_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath_cus,
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
    save_weights_only=False,
)
callbacks = [
    EarlyStopping(patience=5, monitor="val_accuracy", mode="max"),
    model_cus_checkpoint_callback,
]
history_customized = model_cus.fit(
    training_set,
    epochs=epochs,
    validation_data=val_set,
    callbacks=callbacks,
)
Epoch 1/50 244/244 [==============================] - 253s 967ms/step - loss: 3.5161 - accuracy: 0.0810 - val_loss: 2.9515 - val_accuracy: 0.2262 Epoch 2/50 244/244 [==============================] - 111s 451ms/step - loss: 3.0078 - accuracy: 0.1891 - val_loss: 2.3095 - val_accuracy: 0.3739 Epoch 3/50 244/244 [==============================] - 111s 451ms/step - loss: 2.5875 - accuracy: 0.2971 - val_loss: 1.8757 - val_accuracy: 0.5089 Epoch 4/50 244/244 [==============================] - 110s 449ms/step - loss: 2.2122 - accuracy: 0.3975 - val_loss: 1.4827 - val_accuracy: 0.5909 Epoch 5/50 244/244 [==============================] - 110s 450ms/step - loss: 1.8888 - accuracy: 0.4776 - val_loss: 1.2111 - val_accuracy: 0.6585 Epoch 6/50 244/244 [==============================] - 110s 449ms/step - loss: 1.6264 - accuracy: 0.5424 - val_loss: 0.9674 - val_accuracy: 0.7236 Epoch 7/50 244/244 [==============================] - 110s 451ms/step - loss: 1.4421 - accuracy: 0.5960 - val_loss: 0.9006 - val_accuracy: 0.7472 Epoch 8/50 244/244 [==============================] - 111s 452ms/step - loss: 1.3089 - accuracy: 0.6338 - val_loss: 0.7888 - val_accuracy: 0.7868 Epoch 9/50 244/244 [==============================] - 111s 452ms/step - loss: 1.1793 - accuracy: 0.6723 - val_loss: 0.7723 - val_accuracy: 0.8003 Epoch 10/50 244/244 [==============================] - 108s 441ms/step - loss: 1.0857 - accuracy: 0.6912 - val_loss: 0.7476 - val_accuracy: 0.7974 Epoch 11/50 244/244 [==============================] - 110s 451ms/step - loss: 1.0126 - accuracy: 0.7176 - val_loss: 0.6013 - val_accuracy: 0.8249 Epoch 12/50 244/244 [==============================] - 111s 452ms/step - loss: 0.9694 - accuracy: 0.7272 - val_loss: 0.5753 - val_accuracy: 0.8452 Epoch 13/50 244/244 [==============================] - 111s 451ms/step - loss: 0.8851 - accuracy: 0.7547 - val_loss: 0.5176 - val_accuracy: 0.8543 Epoch 14/50 244/244 [==============================] - 111s 452ms/step - loss: 0.8271 - accuracy: 
0.7679 - val_loss: 0.5140 - val_accuracy: 0.8640 Epoch 15/50 244/244 [==============================] - 111s 455ms/step - loss: 0.7942 - accuracy: 0.7784 - val_loss: 0.4868 - val_accuracy: 0.8673 Epoch 16/50 244/244 [==============================] - 111s 452ms/step - loss: 0.7446 - accuracy: 0.7925 - val_loss: 0.4451 - val_accuracy: 0.8799 Epoch 17/50 244/244 [==============================] - 108s 441ms/step - loss: 0.7371 - accuracy: 0.7955 - val_loss: 0.4221 - val_accuracy: 0.8799 Epoch 18/50 244/244 [==============================] - 111s 453ms/step - loss: 0.6932 - accuracy: 0.8059 - val_loss: 0.4021 - val_accuracy: 0.8910 Epoch 19/50 244/244 [==============================] - 111s 452ms/step - loss: 0.6523 - accuracy: 0.8150 - val_loss: 0.4047 - val_accuracy: 0.8915 Epoch 20/50 244/244 [==============================] - 108s 440ms/step - loss: 0.6475 - accuracy: 0.8170 - val_loss: 0.4121 - val_accuracy: 0.8900 Epoch 21/50 244/244 [==============================] - 111s 453ms/step - loss: 0.6202 - accuracy: 0.8271 - val_loss: 0.3896 - val_accuracy: 0.9006 Epoch 22/50 244/244 [==============================] - 108s 440ms/step - loss: 0.5927 - accuracy: 0.8335 - val_loss: 0.4148 - val_accuracy: 0.8953 Epoch 23/50 244/244 [==============================] - 110s 451ms/step - loss: 0.5646 - accuracy: 0.8394 - val_loss: 0.3565 - val_accuracy: 0.9011 Epoch 24/50 244/244 [==============================] - 112s 457ms/step - loss: 0.5653 - accuracy: 0.8414 - val_loss: 0.3231 - val_accuracy: 0.9074 Epoch 25/50 244/244 [==============================] - 108s 439ms/step - loss: 0.5370 - accuracy: 0.8479 - val_loss: 0.3804 - val_accuracy: 0.9055 Epoch 26/50 244/244 [==============================] - 111s 452ms/step - loss: 0.5195 - accuracy: 0.8517 - val_loss: 0.3456 - val_accuracy: 0.9112 Epoch 27/50 244/244 [==============================] - 108s 442ms/step - loss: 0.5034 - accuracy: 0.8591 - val_loss: 0.3566 - val_accuracy: 0.9011 Epoch 28/50 244/244 
[==============================] - 112s 456ms/step - loss: 0.5079 - accuracy: 0.8561 - val_loss: 0.3174 - val_accuracy: 0.9137 Epoch 29/50 244/244 [==============================] - 108s 442ms/step - loss: 0.5035 - accuracy: 0.8583 - val_loss: 0.3546 - val_accuracy: 0.9026 Epoch 30/50 244/244 [==============================] - 108s 443ms/step - loss: 0.4676 - accuracy: 0.8668 - val_loss: 0.3161 - val_accuracy: 0.9122 Epoch 31/50 244/244 [==============================] - 112s 457ms/step - loss: 0.4546 - accuracy: 0.8722 - val_loss: 0.3156 - val_accuracy: 0.9146 Epoch 32/50 244/244 [==============================] - 110s 449ms/step - loss: 0.4815 - accuracy: 0.8641 - val_loss: 0.3149 - val_accuracy: 0.9161 Epoch 33/50 244/244 [==============================] - 108s 440ms/step - loss: 0.4420 - accuracy: 0.8762 - val_loss: 0.3658 - val_accuracy: 0.9055 Epoch 34/50 244/244 [==============================] - 108s 440ms/step - loss: 0.4534 - accuracy: 0.8720 - val_loss: 0.3704 - val_accuracy: 0.9122 Epoch 35/50 244/244 [==============================] - 109s 444ms/step - loss: 0.4304 - accuracy: 0.8778 - val_loss: 0.3176 - val_accuracy: 0.9161 Epoch 36/50 244/244 [==============================] - 118s 483ms/step - loss: 0.4246 - accuracy: 0.8789 - val_loss: 0.2930 - val_accuracy: 0.9228 Epoch 37/50 244/244 [==============================] - 110s 448ms/step - loss: 0.4324 - accuracy: 0.8771 - val_loss: 0.2906 - val_accuracy: 0.9257 Epoch 38/50 244/244 [==============================] - 107s 437ms/step - loss: 0.4279 - accuracy: 0.8802 - val_loss: 0.2919 - val_accuracy: 0.9204 Epoch 39/50 244/244 [==============================] - 108s 443ms/step - loss: 0.4302 - accuracy: 0.8785 - val_loss: 0.3516 - val_accuracy: 0.9059 Epoch 40/50 244/244 [==============================] - 108s 442ms/step - loss: 0.4020 - accuracy: 0.8874 - val_loss: 0.3033 - val_accuracy: 0.9209 Epoch 41/50 244/244 [==============================] - 108s 440ms/step - loss: 0.4164 - accuracy: 0.8831 - 
val_loss: 0.2951 - val_accuracy: 0.9194 Epoch 42/50 244/244 [==============================] - 111s 454ms/step - loss: 0.3968 - accuracy: 0.8892 - val_loss: 0.2789 - val_accuracy: 0.9296 Epoch 43/50 244/244 [==============================] - 107s 438ms/step - loss: 0.4069 - accuracy: 0.8845 - val_loss: 0.2901 - val_accuracy: 0.9257 Epoch 44/50 244/244 [==============================] - 109s 446ms/step - loss: 0.3912 - accuracy: 0.8913 - val_loss: 0.2981 - val_accuracy: 0.9320 Epoch 45/50 244/244 [==============================] - 108s 440ms/step - loss: 0.3743 - accuracy: 0.8937 - val_loss: 0.2791 - val_accuracy: 0.9301 Epoch 46/50 244/244 [==============================] - 107s 438ms/step - loss: 0.3811 - accuracy: 0.8919 - val_loss: 0.2658 - val_accuracy: 0.9305 Epoch 47/50 244/244 [==============================] - 107s 439ms/step - loss: 0.3773 - accuracy: 0.8949 - val_loss: 0.2787 - val_accuracy: 0.9281 Epoch 48/50 244/244 [==============================] - 108s 441ms/step - loss: 0.3742 - accuracy: 0.8958 - val_loss: 0.2779 - val_accuracy: 0.9257 Epoch 49/50 244/244 [==============================] - 108s 443ms/step - loss: 0.3736 - accuracy: 0.8965 - val_loss: 0.2718 - val_accuracy: 0.9276
Trends for accuracy and loss: training dataset versus validation set
Load highest val-accuracy model
# Reload the checkpoint with the best validation accuracy (the in-memory
# model holds the *last* epoch's weights, not the best ones).
best_from_custome_customized_model = keras.models.load_model(checkpoint_filepath_cus)
Table below is referenced from Keras (https://keras.io/api/applications/)
| Model | Size (MB) | Top-1 Accuracy | Top-5 Accuracy | Parameters | Depth | Time (ms) per inference step (CPU) | Time (ms) per inference step (GPU) |
|---|---|---|---|---|---|---|---|
| Xception | 88 | 79.0% | 94.5% | 22.9M | 81 | 109.4 | 8.1 |
| VGG16 | 528 | 71.3% | 90.1% | 138.4M | 16 | 69.5 | 4.2 |
| VGG19 | 549 | 71.3% | 90.0% | 143.7M | 19 | 84.8 | 4.4 |
| ResNet50 | 98 | 74.9% | 92.1% | 25.6M | 107 | 58.2 | 4.6 |
| ResNet50V2 | 98 | 76.0% | 93.0% | 25.6M | 103 | 45.6 | 4.4 |
| ResNet101 | 171 | 76.4% | 92.8% | 44.7M | 209 | 89.6 | 5.2 |
| ResNet101V2 | 171 | 77.2% | 93.8% | 44.7M | 205 | 72.7 | 5.4 |
| ResNet152 | 232 | 76.6% | 93.1% | 60.4M | 311 | 127.4 | 6.5 |
| ResNet152V2 | 232 | 78.0% | 94.2% | 60.4M | 307 | 107.5 | 6.6 |
| InceptionV3 | 92 | 77.9% | 93.7% | 23.9M | 189 | 42.2 | 6.9 |
| InceptionResNetV2 | 215 | 80.3% | 95.3% | 55.9M | 449 | 130.2 | 10.0 |
| MobileNet | 16 | 70.4% | 89.5% | 4.3M | 55 | 22.6 | 3.4 |
| MobileNetV2 | 14 | 71.3% | 90.1% | 3.5M | 105 | 25.9 | 3.8 |
| DenseNet121 | 33 | 75.0% | 92.3% | 8.1M | 242 | 77.1 | 5.4 |
| DenseNet169 | 57 | 76.2% | 93.2% | 14.3M | 338 | 96.4 | 6.3 |
| DenseNet201 | 80 | 77.3% | 93.6% | 20.2M | 402 | 127.2 | 6.7 |
| NASNetMobile | 23 | 74.4% | 91.9% | 5.3M | 389 | 27.0 | 6.7 |
| NASNetLarge | 343 | 82.5% | 96.0% | 88.9M | 533 | 344.5 | 20.0 |
| EfficientNetB0 | 29 | 77.1% | 93.3% | 5.3M | 132 | 46.0 | 4.9 |
| EfficientNetB1 | 31 | 79.1% | 94.4% | 7.9M | 186 | 60.2 | 5.6 |
| EfficientNetB2 | 36 | 80.1% | 94.9% | 9.2M | 186 | 80.8 | 6.5 |
| EfficientNetB3 | 48 | 81.6% | 95.7% | 12.3M | 210 | 140.0 | 8.8 |
| EfficientNetB4 | 75 | 82.9% | 96.4% | 19.5M | 258 | 308.3 | 15.1 |
| EfficientNetB5 | 118 | 83.6% | 96.7% | 30.6M | 312 | 579.2 | 25.3 |
| EfficientNetB6 | 166 | 84.0% | 96.8% | 43.3M | 360 | 958.1 | 40.4 |
| EfficientNetB7 | 256 | 84.3% | 97.0% | 66.7M | 438 | 1578.9 | 61.6 |
| EfficientNetV2B0 | 29 | 78.7% | 94.3% | 7.2M | - | - | - |
| EfficientNetV2B1 | 34 | 79.8% | 95.0% | 8.2M | - | - | - |
| EfficientNetV2B2 | 42 | 80.5% | 95.1% | 10.2M | - | - | - |
| EfficientNetV2B3 | 59 | 82.0% | 95.8% | 14.5M | - | - | - |
| EfficientNetV2S | 88 | 83.9% | 96.7% | 21.6M | - | - | - |
| EfficientNetV2M | 220 | 85.3% | 97.4% | 54.4M | - | - | - |
| EfficientNetV2L | 479 | 85.7% | 97.5% | 119.0M | - | - | - |
| ConvNeXtTiny | 109.42 | 81.3% | - | 28.6M | - | - | - |
| ConvNeXtSmall | 192.29 | 82.3% | - | 50.2M | - | - | - |
| ConvNeXtBase | 338.58 | 85.3% | - | 88.5M | - | - | - |
| ConvNeXtLarge | 755.07 | 86.3% | - | 197.7M | - | - | - |
| ConvNeXtXLarge | 1310 | 86.7% | - | 350.1M | - | - | - |
The following is referenced from Recognize the Simpsons TensorFlow (88% acc.)
# Load the pretrained MobileNetV2 backbone (ImageNet weights, classifier head
# removed, features global-average pooled to a 1280-vector).
# FIX: the original hard-coded input_shape=(224, 224, 3) while every generator
# yields image_size = (180, 180, 3); declaring the real input size keeps the
# model consistent with the data pipeline and avoids the shape-mismatch
# warning Keras emits at fit time.
pretrained_model = tf.keras.applications.MobileNetV2(
    input_shape=image_size,
    include_top=False,
    weights='imagenet',
    pooling='avg'
)
pretrained_model.trainable = False  # freeze the backbone; only the head trains

inputs = pretrained_model.input
# NOTE(review): MobileNetV2's ImageNet weights expect inputs scaled to [-1, 1]
# (tf.keras.applications.mobilenet_v2.preprocess_input), but the generators
# rescale to [0, 1] — worth confirming/aligning for better transfer accuracy.
x = tf.keras.layers.Dense(128, activation="relu")(pretrained_model.output)
x = tf.keras.layers.Dense(1024, activation="relu")(x)
outputs = tf.keras.layers.Dense(42, activation="softmax")(x)
model_MNV2 = tf.keras.Model(inputs=inputs, outputs=outputs)
model_MNV2.compile(
    optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]
)
model_MNV2.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 224, 224, 3 0 []
)]
Conv1 (Conv2D) (None, 112, 112, 32 864 ['input_1[0][0]']
)
bn_Conv1 (BatchNormalization) (None, 112, 112, 32 128 ['Conv1[0][0]']
)
Conv1_relu (ReLU) (None, 112, 112, 32 0 ['bn_Conv1[0][0]']
)
expanded_conv_depthwise (Depth (None, 112, 112, 32 288 ['Conv1_relu[0][0]']
wiseConv2D) )
expanded_conv_depthwise_BN (Ba (None, 112, 112, 32 128 ['expanded_conv_depthwise[0][0]']
tchNormalization) )
expanded_conv_depthwise_relu ( (None, 112, 112, 32 0 ['expanded_conv_depthwise_BN[0][0
ReLU) ) ]']
expanded_conv_project (Conv2D) (None, 112, 112, 16 512 ['expanded_conv_depthwise_relu[0]
) [0]']
expanded_conv_project_BN (Batc (None, 112, 112, 16 64 ['expanded_conv_project[0][0]']
hNormalization) )
block_1_expand (Conv2D) (None, 112, 112, 96 1536 ['expanded_conv_project_BN[0][0]'
) ]
block_1_expand_BN (BatchNormal (None, 112, 112, 96 384 ['block_1_expand[0][0]']
ization) )
block_1_expand_relu (ReLU) (None, 112, 112, 96 0 ['block_1_expand_BN[0][0]']
)
block_1_pad (ZeroPadding2D) (None, 113, 113, 96 0 ['block_1_expand_relu[0][0]']
)
block_1_depthwise (DepthwiseCo (None, 56, 56, 96) 864 ['block_1_pad[0][0]']
nv2D)
block_1_depthwise_BN (BatchNor (None, 56, 56, 96) 384 ['block_1_depthwise[0][0]']
malization)
block_1_depthwise_relu (ReLU) (None, 56, 56, 96) 0 ['block_1_depthwise_BN[0][0]']
block_1_project (Conv2D) (None, 56, 56, 24) 2304 ['block_1_depthwise_relu[0][0]']
block_1_project_BN (BatchNorma (None, 56, 56, 24) 96 ['block_1_project[0][0]']
lization)
block_2_expand (Conv2D) (None, 56, 56, 144) 3456 ['block_1_project_BN[0][0]']
block_2_expand_BN (BatchNormal (None, 56, 56, 144) 576 ['block_2_expand[0][0]']
ization)
block_2_expand_relu (ReLU) (None, 56, 56, 144) 0 ['block_2_expand_BN[0][0]']
block_2_depthwise (DepthwiseCo (None, 56, 56, 144) 1296 ['block_2_expand_relu[0][0]']
nv2D)
block_2_depthwise_BN (BatchNor (None, 56, 56, 144) 576 ['block_2_depthwise[0][0]']
malization)
block_2_depthwise_relu (ReLU) (None, 56, 56, 144) 0 ['block_2_depthwise_BN[0][0]']
block_2_project (Conv2D) (None, 56, 56, 24) 3456 ['block_2_depthwise_relu[0][0]']
block_2_project_BN (BatchNorma (None, 56, 56, 24) 96 ['block_2_project[0][0]']
lization)
block_2_add (Add) (None, 56, 56, 24) 0 ['block_1_project_BN[0][0]',
'block_2_project_BN[0][0]']
block_3_expand (Conv2D) (None, 56, 56, 144) 3456 ['block_2_add[0][0]']
block_3_expand_BN (BatchNormal (None, 56, 56, 144) 576 ['block_3_expand[0][0]']
ization)
block_3_expand_relu (ReLU) (None, 56, 56, 144) 0 ['block_3_expand_BN[0][0]']
block_3_pad (ZeroPadding2D) (None, 57, 57, 144) 0 ['block_3_expand_relu[0][0]']
block_3_depthwise (DepthwiseCo (None, 28, 28, 144) 1296 ['block_3_pad[0][0]']
nv2D)
block_3_depthwise_BN (BatchNor (None, 28, 28, 144) 576 ['block_3_depthwise[0][0]']
malization)
block_3_depthwise_relu (ReLU) (None, 28, 28, 144) 0 ['block_3_depthwise_BN[0][0]']
block_3_project (Conv2D) (None, 28, 28, 32) 4608 ['block_3_depthwise_relu[0][0]']
block_3_project_BN (BatchNorma (None, 28, 28, 32) 128 ['block_3_project[0][0]']
lization)
block_4_expand (Conv2D) (None, 28, 28, 192) 6144 ['block_3_project_BN[0][0]']
block_4_expand_BN (BatchNormal (None, 28, 28, 192) 768 ['block_4_expand[0][0]']
ization)
block_4_expand_relu (ReLU) (None, 28, 28, 192) 0 ['block_4_expand_BN[0][0]']
block_4_depthwise (DepthwiseCo (None, 28, 28, 192) 1728 ['block_4_expand_relu[0][0]']
nv2D)
block_4_depthwise_BN (BatchNor (None, 28, 28, 192) 768 ['block_4_depthwise[0][0]']
malization)
block_4_depthwise_relu (ReLU) (None, 28, 28, 192) 0 ['block_4_depthwise_BN[0][0]']
block_4_project (Conv2D) (None, 28, 28, 32) 6144 ['block_4_depthwise_relu[0][0]']
block_4_project_BN (BatchNorma (None, 28, 28, 32) 128 ['block_4_project[0][0]']
lization)
block_4_add (Add) (None, 28, 28, 32) 0 ['block_3_project_BN[0][0]',
'block_4_project_BN[0][0]']
block_5_expand (Conv2D) (None, 28, 28, 192) 6144 ['block_4_add[0][0]']
block_5_expand_BN (BatchNormal (None, 28, 28, 192) 768 ['block_5_expand[0][0]']
ization)
block_5_expand_relu (ReLU) (None, 28, 28, 192) 0 ['block_5_expand_BN[0][0]']
block_5_depthwise (DepthwiseCo (None, 28, 28, 192) 1728 ['block_5_expand_relu[0][0]']
nv2D)
block_5_depthwise_BN (BatchNor (None, 28, 28, 192) 768 ['block_5_depthwise[0][0]']
malization)
block_5_depthwise_relu (ReLU) (None, 28, 28, 192) 0 ['block_5_depthwise_BN[0][0]']
block_5_project (Conv2D) (None, 28, 28, 32) 6144 ['block_5_depthwise_relu[0][0]']
block_5_project_BN (BatchNorma (None, 28, 28, 32) 128 ['block_5_project[0][0]']
lization)
block_5_add (Add) (None, 28, 28, 32) 0 ['block_4_add[0][0]',
'block_5_project_BN[0][0]']
block_6_expand (Conv2D) (None, 28, 28, 192) 6144 ['block_5_add[0][0]']
block_6_expand_BN (BatchNormal (None, 28, 28, 192) 768 ['block_6_expand[0][0]']
ization)
block_6_expand_relu (ReLU) (None, 28, 28, 192) 0 ['block_6_expand_BN[0][0]']
block_6_pad (ZeroPadding2D) (None, 29, 29, 192) 0 ['block_6_expand_relu[0][0]']
block_6_depthwise (DepthwiseCo (None, 14, 14, 192) 1728 ['block_6_pad[0][0]']
nv2D)
block_6_depthwise_BN (BatchNor (None, 14, 14, 192) 768 ['block_6_depthwise[0][0]']
malization)
block_6_depthwise_relu (ReLU) (None, 14, 14, 192) 0 ['block_6_depthwise_BN[0][0]']
block_6_project (Conv2D) (None, 14, 14, 64) 12288 ['block_6_depthwise_relu[0][0]']
block_6_project_BN (BatchNorma (None, 14, 14, 64) 256 ['block_6_project[0][0]']
lization)
block_7_expand (Conv2D) (None, 14, 14, 384) 24576 ['block_6_project_BN[0][0]']
block_7_expand_BN (BatchNormal (None, 14, 14, 384) 1536 ['block_7_expand[0][0]']
ization)
block_7_expand_relu (ReLU) (None, 14, 14, 384) 0 ['block_7_expand_BN[0][0]']
block_7_depthwise (DepthwiseCo (None, 14, 14, 384) 3456 ['block_7_expand_relu[0][0]']
nv2D)
block_7_depthwise_BN (BatchNor (None, 14, 14, 384) 1536 ['block_7_depthwise[0][0]']
malization)
block_7_depthwise_relu (ReLU) (None, 14, 14, 384) 0 ['block_7_depthwise_BN[0][0]']
block_7_project (Conv2D) (None, 14, 14, 64) 24576 ['block_7_depthwise_relu[0][0]']
block_7_project_BN (BatchNorma (None, 14, 14, 64) 256 ['block_7_project[0][0]']
lization)
block_7_add (Add) (None, 14, 14, 64) 0 ['block_6_project_BN[0][0]',
'block_7_project_BN[0][0]']
block_8_expand (Conv2D) (None, 14, 14, 384) 24576 ['block_7_add[0][0]']
block_8_expand_BN (BatchNormal (None, 14, 14, 384) 1536 ['block_8_expand[0][0]']
ization)
block_8_expand_relu (ReLU) (None, 14, 14, 384) 0 ['block_8_expand_BN[0][0]']
block_8_depthwise (DepthwiseCo (None, 14, 14, 384) 3456 ['block_8_expand_relu[0][0]']
nv2D)
block_8_depthwise_BN (BatchNor (None, 14, 14, 384) 1536 ['block_8_depthwise[0][0]']
malization)
block_8_depthwise_relu (ReLU) (None, 14, 14, 384) 0 ['block_8_depthwise_BN[0][0]']
block_8_project (Conv2D) (None, 14, 14, 64) 24576 ['block_8_depthwise_relu[0][0]']
block_8_project_BN (BatchNorma (None, 14, 14, 64) 256 ['block_8_project[0][0]']
lization)
block_8_add (Add) (None, 14, 14, 64) 0 ['block_7_add[0][0]',
'block_8_project_BN[0][0]']
block_9_expand (Conv2D) (None, 14, 14, 384) 24576 ['block_8_add[0][0]']
block_9_expand_BN (BatchNormal (None, 14, 14, 384) 1536 ['block_9_expand[0][0]']
ization)
block_9_expand_relu (ReLU) (None, 14, 14, 384) 0 ['block_9_expand_BN[0][0]']
block_9_depthwise (DepthwiseCo (None, 14, 14, 384) 3456 ['block_9_expand_relu[0][0]']
nv2D)
block_9_depthwise_BN (BatchNor (None, 14, 14, 384) 1536 ['block_9_depthwise[0][0]']
malization)
block_9_depthwise_relu (ReLU) (None, 14, 14, 384) 0 ['block_9_depthwise_BN[0][0]']
block_9_project (Conv2D) (None, 14, 14, 64) 24576 ['block_9_depthwise_relu[0][0]']
block_9_project_BN (BatchNorma (None, 14, 14, 64) 256 ['block_9_project[0][0]']
lization)
block_9_add (Add) (None, 14, 14, 64) 0 ['block_8_add[0][0]',
'block_9_project_BN[0][0]']
block_10_expand (Conv2D) (None, 14, 14, 384) 24576 ['block_9_add[0][0]']
block_10_expand_BN (BatchNorma (None, 14, 14, 384) 1536 ['block_10_expand[0][0]']
lization)
block_10_expand_relu (ReLU) (None, 14, 14, 384) 0 ['block_10_expand_BN[0][0]']
block_10_depthwise (DepthwiseC (None, 14, 14, 384) 3456 ['block_10_expand_relu[0][0]']
onv2D)
block_10_depthwise_BN (BatchNo (None, 14, 14, 384) 1536 ['block_10_depthwise[0][0]']
rmalization)
block_10_depthwise_relu (ReLU) (None, 14, 14, 384) 0 ['block_10_depthwise_BN[0][0]']
block_10_project (Conv2D) (None, 14, 14, 96) 36864 ['block_10_depthwise_relu[0][0]']
block_10_project_BN (BatchNorm (None, 14, 14, 96) 384 ['block_10_project[0][0]']
alization)
block_11_expand (Conv2D) (None, 14, 14, 576) 55296 ['block_10_project_BN[0][0]']
block_11_expand_BN (BatchNorma (None, 14, 14, 576) 2304 ['block_11_expand[0][0]']
lization)
block_11_expand_relu (ReLU) (None, 14, 14, 576) 0 ['block_11_expand_BN[0][0]']
block_11_depthwise (DepthwiseC (None, 14, 14, 576) 5184 ['block_11_expand_relu[0][0]']
onv2D)
block_11_depthwise_BN (BatchNo (None, 14, 14, 576) 2304 ['block_11_depthwise[0][0]']
rmalization)
block_11_depthwise_relu (ReLU) (None, 14, 14, 576) 0 ['block_11_depthwise_BN[0][0]']
block_11_project (Conv2D) (None, 14, 14, 96) 55296 ['block_11_depthwise_relu[0][0]']
block_11_project_BN (BatchNorm (None, 14, 14, 96) 384 ['block_11_project[0][0]']
alization)
block_11_add (Add) (None, 14, 14, 96) 0 ['block_10_project_BN[0][0]',
'block_11_project_BN[0][0]']
block_12_expand (Conv2D) (None, 14, 14, 576) 55296 ['block_11_add[0][0]']
block_12_expand_BN (BatchNorma (None, 14, 14, 576) 2304 ['block_12_expand[0][0]']
lization)
block_12_expand_relu (ReLU) (None, 14, 14, 576) 0 ['block_12_expand_BN[0][0]']
block_12_depthwise (DepthwiseC (None, 14, 14, 576) 5184 ['block_12_expand_relu[0][0]']
onv2D)
block_12_depthwise_BN (BatchNo (None, 14, 14, 576) 2304 ['block_12_depthwise[0][0]']
rmalization)
block_12_depthwise_relu (ReLU) (None, 14, 14, 576) 0 ['block_12_depthwise_BN[0][0]']
block_12_project (Conv2D) (None, 14, 14, 96) 55296 ['block_12_depthwise_relu[0][0]']
block_12_project_BN (BatchNorm (None, 14, 14, 96) 384 ['block_12_project[0][0]']
alization)
block_12_add (Add) (None, 14, 14, 96) 0 ['block_11_add[0][0]',
'block_12_project_BN[0][0]']
block_13_expand (Conv2D) (None, 14, 14, 576) 55296 ['block_12_add[0][0]']
block_13_expand_BN (BatchNorma (None, 14, 14, 576) 2304 ['block_13_expand[0][0]']
lization)
block_13_expand_relu (ReLU) (None, 14, 14, 576) 0 ['block_13_expand_BN[0][0]']
block_13_pad (ZeroPadding2D) (None, 15, 15, 576) 0 ['block_13_expand_relu[0][0]']
block_13_depthwise (DepthwiseC (None, 7, 7, 576) 5184 ['block_13_pad[0][0]']
onv2D)
block_13_depthwise_BN (BatchNo (None, 7, 7, 576) 2304 ['block_13_depthwise[0][0]']
rmalization)
block_13_depthwise_relu (ReLU) (None, 7, 7, 576) 0 ['block_13_depthwise_BN[0][0]']
block_13_project (Conv2D) (None, 7, 7, 160) 92160 ['block_13_depthwise_relu[0][0]']
block_13_project_BN (BatchNorm (None, 7, 7, 160) 640 ['block_13_project[0][0]']
alization)
block_14_expand (Conv2D) (None, 7, 7, 960) 153600 ['block_13_project_BN[0][0]']
block_14_expand_BN (BatchNorma (None, 7, 7, 960) 3840 ['block_14_expand[0][0]']
lization)
block_14_expand_relu (ReLU) (None, 7, 7, 960) 0 ['block_14_expand_BN[0][0]']
block_14_depthwise (DepthwiseC (None, 7, 7, 960) 8640 ['block_14_expand_relu[0][0]']
onv2D)
block_14_depthwise_BN (BatchNo (None, 7, 7, 960) 3840 ['block_14_depthwise[0][0]']
rmalization)
block_14_depthwise_relu (ReLU) (None, 7, 7, 960) 0 ['block_14_depthwise_BN[0][0]']
block_14_project (Conv2D) (None, 7, 7, 160) 153600 ['block_14_depthwise_relu[0][0]']
block_14_project_BN (BatchNorm (None, 7, 7, 160) 640 ['block_14_project[0][0]']
alization)
block_14_add (Add) (None, 7, 7, 160) 0 ['block_13_project_BN[0][0]',
'block_14_project_BN[0][0]']
block_15_expand (Conv2D) (None, 7, 7, 960) 153600 ['block_14_add[0][0]']
block_15_expand_BN (BatchNorma (None, 7, 7, 960) 3840 ['block_15_expand[0][0]']
lization)
block_15_expand_relu (ReLU) (None, 7, 7, 960) 0 ['block_15_expand_BN[0][0]']
block_15_depthwise (DepthwiseC (None, 7, 7, 960) 8640 ['block_15_expand_relu[0][0]']
onv2D)
block_15_depthwise_BN (BatchNo (None, 7, 7, 960) 3840 ['block_15_depthwise[0][0]']
rmalization)
block_15_depthwise_relu (ReLU) (None, 7, 7, 960) 0 ['block_15_depthwise_BN[0][0]']
block_15_project (Conv2D) (None, 7, 7, 160) 153600 ['block_15_depthwise_relu[0][0]']
block_15_project_BN (BatchNorm (None, 7, 7, 160) 640 ['block_15_project[0][0]']
alization)
block_15_add (Add) (None, 7, 7, 160) 0 ['block_14_add[0][0]',
'block_15_project_BN[0][0]']
block_16_expand (Conv2D) (None, 7, 7, 960) 153600 ['block_15_add[0][0]']
block_16_expand_BN (BatchNorma (None, 7, 7, 960) 3840 ['block_16_expand[0][0]']
lization)
block_16_expand_relu (ReLU) (None, 7, 7, 960) 0 ['block_16_expand_BN[0][0]']
block_16_depthwise (DepthwiseC (None, 7, 7, 960) 8640 ['block_16_expand_relu[0][0]']
onv2D)
block_16_depthwise_BN (BatchNo (None, 7, 7, 960) 3840 ['block_16_depthwise[0][0]']
rmalization)
block_16_depthwise_relu (ReLU) (None, 7, 7, 960) 0 ['block_16_depthwise_BN[0][0]']
block_16_project (Conv2D) (None, 7, 7, 320) 307200 ['block_16_depthwise_relu[0][0]']
block_16_project_BN (BatchNorm (None, 7, 7, 320) 1280 ['block_16_project[0][0]']
alization)
Conv_1 (Conv2D) (None, 7, 7, 1280) 409600 ['block_16_project_BN[0][0]']
Conv_1_bn (BatchNormalization) (None, 7, 7, 1280) 5120 ['Conv_1[0][0]']
out_relu (ReLU) (None, 7, 7, 1280) 0 ['Conv_1_bn[0][0]']
global_average_pooling2d (Glob (None, 1280) 0 ['out_relu[0][0]']
alAveragePooling2D)
dense_2 (Dense) (None, 128) 163968 ['global_average_pooling2d[0][0]'
]
dense_3 (Dense) (None, 1024) 132096 ['dense_2[0][0]']
dense_4 (Dense) (None, 42) 43050 ['dense_3[0][0]']
==================================================================================================
Total params: 2,597,098
Trainable params: 339,114
Non-trainable params: 2,257,984
__________________________________________________________________________________________________
# Compile the MobileNetV2-based classifier for 42-way categorical prediction.
model_MNV2.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Persist the full model whenever validation accuracy improves.
checkpoint_filepath_MNV2 = "./models/MNV2.h5"
model_MNV2_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath_MNV2,
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
    save_weights_only=False,
)

# Stop training once validation accuracy stalls for 5 epochs, restoring
# the best weights seen so far.
callbacks_MNV2 = [
    model_MNV2_checkpoint_callback,
    EarlyStopping(monitor="val_accuracy", patience=5, restore_best_weights=True),
]

history_MNV2 = model_MNV2.fit(
    training_set,
    validation_data=val_set,
    epochs=50,
    callbacks=callbacks_MNV2,
)
Epoch 1/50 244/244 [==============================] - 108s 434ms/step - loss: 2.4468 - accuracy: 0.3495 - val_loss: 1.8542 - val_accuracy: 0.4954 Epoch 2/50 244/244 [==============================] - 106s 435ms/step - loss: 1.7848 - accuracy: 0.5089 - val_loss: 1.7979 - val_accuracy: 0.5104 Epoch 3/50 244/244 [==============================] - 106s 434ms/step - loss: 1.5545 - accuracy: 0.5696 - val_loss: 1.5409 - val_accuracy: 0.5712 Epoch 4/50 244/244 [==============================] - 106s 433ms/step - loss: 1.4255 - accuracy: 0.5974 - val_loss: 1.5524 - val_accuracy: 0.5678 Epoch 5/50 244/244 [==============================] - 105s 429ms/step - loss: 1.3348 - accuracy: 0.6188 - val_loss: 1.4764 - val_accuracy: 0.5861 Epoch 6/50 244/244 [==============================] - 106s 433ms/step - loss: 1.2664 - accuracy: 0.6415 - val_loss: 1.4505 - val_accuracy: 0.5953 Epoch 7/50 244/244 [==============================] - 106s 433ms/step - loss: 1.2224 - accuracy: 0.6489 - val_loss: 1.4238 - val_accuracy: 0.5996 Epoch 8/50 244/244 [==============================] - 105s 432ms/step - loss: 1.1595 - accuracy: 0.6668 - val_loss: 1.4417 - val_accuracy: 0.6088 Epoch 9/50 244/244 [==============================] - 107s 437ms/step - loss: 1.1099 - accuracy: 0.6773 - val_loss: 1.3215 - val_accuracy: 0.6213 Epoch 10/50 244/244 [==============================] - 106s 433ms/step - loss: 1.0765 - accuracy: 0.6885 - val_loss: 1.3723 - val_accuracy: 0.6228 Epoch 11/50 244/244 [==============================] - 106s 433ms/step - loss: 1.0483 - accuracy: 0.6939 - val_loss: 1.3970 - val_accuracy: 0.6107 Epoch 12/50 244/244 [==============================] - 106s 433ms/step - loss: 1.0338 - accuracy: 0.6988 - val_loss: 1.3952 - val_accuracy: 0.6141 Epoch 13/50 244/244 [==============================] - 106s 435ms/step - loss: 1.0124 - accuracy: 0.7035 - val_loss: 1.3422 - val_accuracy: 0.6343 Epoch 14/50 244/244 [==============================] - 105s 432ms/step - loss: 0.9675 - accuracy: 
0.7169 - val_loss: 1.3468 - val_accuracy: 0.6295 Epoch 15/50 244/244 [==============================] - 106s 435ms/step - loss: 0.9460 - accuracy: 0.7221 - val_loss: 1.2937 - val_accuracy: 0.6334 Epoch 16/50 244/244 [==============================] - 106s 436ms/step - loss: 0.9322 - accuracy: 0.7274 - val_loss: 1.3918 - val_accuracy: 0.6165 Epoch 17/50 244/244 [==============================] - 105s 432ms/step - loss: 0.9245 - accuracy: 0.7264 - val_loss: 1.5238 - val_accuracy: 0.5871 Epoch 18/50 244/244 [==============================] - 106s 435ms/step - loss: 0.9065 - accuracy: 0.7315 - val_loss: 1.2946 - val_accuracy: 0.6503 Epoch 19/50 244/244 [==============================] - 105s 432ms/step - loss: 0.9094 - accuracy: 0.7346 - val_loss: 1.3655 - val_accuracy: 0.6397 Epoch 20/50 244/244 [==============================] - 105s 430ms/step - loss: 0.8810 - accuracy: 0.7383 - val_loss: 1.3503 - val_accuracy: 0.6310 Epoch 21/50 244/244 [==============================] - 106s 433ms/step - loss: 0.8728 - accuracy: 0.7423 - val_loss: 1.3489 - val_accuracy: 0.6329 Epoch 22/50 244/244 [==============================] - 105s 431ms/step - loss: 0.8550 - accuracy: 0.7450 - val_loss: 1.3282 - val_accuracy: 0.6387 Epoch 23/50 244/244 [==============================] - 106s 434ms/step - loss: 0.8492 - accuracy: 0.7506 - val_loss: 1.3459 - val_accuracy: 0.6358
Load highest val-accuracy model - Mobile Net V2
# Reload the MobileNetV2 checkpoint with the highest validation accuracy.
best_from_custome_MNV2_model = keras.models.load_model(checkpoint_filepath_MNV2)
The following is referenced from CNN + ResNet
from tensorflow.keras.applications.resnet50 import ResNet50

# ResNet50 backbone (no classifier head, global max pooling) followed by a
# small dense head sized to the number of character classes.
model_RN50 = Sequential([
    ResNet50(include_top=False, pooling='max'),
    Flatten(),
    Dense(1024, activation="relu"),
    Dropout(0.5),
    Dense(len(labels), activation='softmax'),
])
model_RN50.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
resnet50 (Functional) (None, 2048) 23587712
flatten_1 (Flatten) (None, 2048) 0
dense_5 (Dense) (None, 1024) 2098176
dropout_4 (Dropout) (None, 1024) 0
dense_6 (Dense) (None, 42) 43050
=================================================================
Total params: 25,728,938
Trainable params: 25,675,818
Non-trainable params: 53,120
_________________________________________________________________
model_RN50.compile(
    loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)

# Checkpoint the ResNet50 model on best validation accuracy.
checkpoint_filepath_RN50 = "./models/RN50.h5"
model_RN50_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath_RN50,
    save_weights_only=False,
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
)

# Fixed: this list was previously assigned to `callbacks_MNV2`, silently
# clobbering the MobileNetV2 callback list defined earlier. Give it its own
# name, and keep the old name as an alias so the existing
# `model_RN50.fit(..., callbacks=callbacks_MNV2)` call below is unaffected.
callbacks_RN50 = [
    model_RN50_checkpoint_callback,
    EarlyStopping(monitor="val_accuracy", patience=5, restore_best_weights=True),
]
callbacks_MNV2 = callbacks_RN50
# Data iterators for the ResNet50 run: same augmentation pipelines as the
# earlier models, but with a smaller batch size of 35.
_rn50_flow_kwargs = dict(
    target_size=image_size[:2],
    batch_size=35,
    class_mode="categorical",
    color_mode="rgb",
)
training_set_RN50 = train_datagen.flow_from_directory(train_dir, **_rn50_flow_kwargs)
val_set_RN50 = test_datagen.flow_from_directory(val_dir, **_rn50_flow_kwargs)
Found 24395 images belonging to 42 classes. Found 2073 images belonging to 42 classes.
# Train ResNet50 with checkpointing + early stopping.
# NOTE(review): `callbacks_MNV2` here actually holds the ResNet50 callback
# list — the cell above reassigned that name instead of using a new
# `callbacks_RN50`; presumably a copy-paste slip. Works, but confusing.
history_RN50 = model_RN50.fit(
    training_set_RN50, epochs=50, validation_data=val_set_RN50, callbacks=callbacks_MNV2
)
Epoch 1/50 697/697 [==============================] - 125s 173ms/step - loss: 3.2562 - accuracy: 0.2537 - val_loss: 4.0804 - val_accuracy: 0.0145 Epoch 2/50 697/697 [==============================] - 120s 172ms/step - loss: 1.6030 - accuracy: 0.5683 - val_loss: 1.8252 - val_accuracy: 0.5297 Epoch 3/50 697/697 [==============================] - 120s 173ms/step - loss: 1.3147 - accuracy: 0.6454 - val_loss: 0.8303 - val_accuracy: 0.7771 Epoch 4/50 697/697 [==============================] - 120s 172ms/step - loss: 1.1980 - accuracy: 0.6780 - val_loss: 5.2982 - val_accuracy: 0.0550 Epoch 5/50 697/697 [==============================] - 120s 172ms/step - loss: 1.0553 - accuracy: 0.7139 - val_loss: 2.3258 - val_accuracy: 0.4568 Epoch 6/50 697/697 [==============================] - 120s 172ms/step - loss: 0.8492 - accuracy: 0.7779 - val_loss: 1.0132 - val_accuracy: 0.7192 Epoch 7/50 697/697 [==============================] - 120s 172ms/step - loss: 1.1842 - accuracy: 0.6836 - val_loss: 1.1694 - val_accuracy: 0.6898 Epoch 8/50 697/697 [==============================] - 121s 173ms/step - loss: 0.8175 - accuracy: 0.7804 - val_loss: 0.6165 - val_accuracy: 0.8249 Epoch 9/50 697/697 [==============================] - 120s 173ms/step - loss: 0.7473 - accuracy: 0.7996 - val_loss: 3.4145 - val_accuracy: 0.1630 Epoch 10/50 697/697 [==============================] - 121s 174ms/step - loss: 0.7452 - accuracy: 0.8001 - val_loss: 0.6542 - val_accuracy: 0.8215 Epoch 11/50 697/697 [==============================] - 121s 173ms/step - loss: 0.7178 - accuracy: 0.8089 - val_loss: 1.9048 - val_accuracy: 0.5355 Epoch 12/50 697/697 [==============================] - 121s 174ms/step - loss: 0.7527 - accuracy: 0.7970 - val_loss: 0.3887 - val_accuracy: 0.8905 Epoch 13/50 697/697 [==============================] - 120s 172ms/step - loss: 0.9015 - accuracy: 0.7684 - val_loss: 4.5562 - val_accuracy: 0.2595 Epoch 14/50 697/697 [==============================] - 120s 173ms/step - loss: 0.7467 - accuracy: 
0.7973 - val_loss: 0.4030 - val_accuracy: 0.8866 Epoch 15/50 697/697 [==============================] - 121s 173ms/step - loss: 0.6788 - accuracy: 0.8173 - val_loss: 1.0538 - val_accuracy: 0.7289 Epoch 16/50 697/697 [==============================] - 121s 173ms/step - loss: 0.7552 - accuracy: 0.7937 - val_loss: 0.7005 - val_accuracy: 0.8027 Epoch 17/50 697/697 [==============================] - 121s 173ms/step - loss: 0.6447 - accuracy: 0.8232 - val_loss: 0.8102 - val_accuracy: 0.7815
Load highest val-accuracy model - Res Net 50
# Reload the ResNet50 checkpoint with the highest validation accuracy.
best_from_custome_RN50_model = keras.models.load_model(checkpoint_filepath_RN50)
# Pull the accuracy/loss curves out of each Keras History for plotting.
def _curves(hist):
    h = hist.history
    n_epochs = range(1, len(h["accuracy"]) + 1)
    return h["accuracy"], h["val_accuracy"], h["loss"], h["val_loss"], n_epochs

acc_customized, val_acc_customized, loss_customized, val_loss_customized, epochs_customized = _curves(history_customized)
acc_MNV2, val_acc_MNV2, loss_MNV2, val_loss_MNV2, epochs_MNV2 = _curves(history_MNV2)
acc_RN50, val_acc_RN50, loss_RN50, val_loss_RN50, epochs_RN50 = _curves(history_RN50)
# Training vs. validation accuracy for the three models, side by side.
fig, axs = plt.subplots(1, 3, figsize=(24, 8))
_acc_panels = [
    (epochs_customized, acc_customized, val_acc_customized, "customized model"),
    (epochs_MNV2, acc_MNV2, val_acc_MNV2, "Mobile Net V2"),
    (epochs_RN50, acc_RN50, val_acc_RN50, "ResNet50"),
]
for ax, (xs, train_acc, val_acc, model_name) in zip(axs, _acc_panels):
    ax.plot(xs, train_acc, "bo", label="Training accuracy")
    ax.plot(xs, val_acc, "b", label="Validation accuracy")
    ax.set_title(f"Training and validation accuracy trained by {model_name}")
    ax.legend()
<matplotlib.legend.Legend at 0x1c31bb43850>
Loss comparison between these three models
# Training vs. validation loss for the three models, side by side.
fig, axs = plt.subplots(1, 3, figsize=(24, 8))
_loss_panels = [
    (epochs_customized, loss_customized, val_loss_customized, "customized model"),
    (epochs_MNV2, loss_MNV2, val_loss_MNV2, "Mobile Net V2"),
    (epochs_RN50, loss_RN50, val_loss_RN50, "ResNet50"),
]
for ax, (xs, train_loss, val_loss, model_name) in zip(axs, _loss_panels):
    ax.plot(xs, train_loss, "bo", label="Training loss")
    ax.plot(xs, val_loss, "b", label="Validation loss")
    ax.set_title(f"Training and validation loss trained by {model_name}")
    ax.legend()
<matplotlib.legend.Legend at 0x1c3324d7490>
# Reload the best checkpoint of each model for evaluation on the test set.
checkpoint_filepath_cus = "./models/cus.h5"
checkpoint_filepath_MNV2 = "./models/MNV2.h5"
checkpoint_filepath_RN50 = "./models/RN50.h5"
best_from_custome_customized_model = keras.models.load_model(checkpoint_filepath_cus)
best_from_custome_MNV2_model = keras.models.load_model(checkpoint_filepath_MNV2)
best_from_custome_RN50_model = keras.models.load_model(checkpoint_filepath_RN50)

# Ground-truth class indices for the test generator. The round trip through
# to_categorical and argmax leaves the integer labels unchanged.
y_true = test_set.classes
y_test_true = np.argmax(
    keras.utils.to_categorical(y_true, num_classes=len(labels)), axis=1
)

# Predicted class indices from the customized model.
y_pred_customized = np.argmax(
    best_from_custome_customized_model.predict(test_set), axis=1
)
22/22 [==============================] - 10s 227ms/step
def find_unequal_indexes(list1, list2):
    """Return the 1-based positions at which the two sequences disagree.

    Positions beyond the end of the shorter sequence are also counted as
    mismatches, so unequal-length inputs are handled gracefully.
    """
    shorter = min(len(list1), len(list2))
    longer = max(len(list1), len(list2))
    mismatches = [
        pos + 1 for pos, (a, b) in enumerate(zip(list1, list2)) if a != b
    ]
    mismatches.extend(range(shorter + 1, longer + 1))
    return mismatches
Total number of mis-classification by customized model
# Count and report the customized model's test-set misclassifications.
unequal_indices_customized = find_unequal_indexes(y_test_true, y_pred_customized)
_mis_count_cus = len(unequal_indices_customized)
_mis_pct_cus = round(_mis_count_cus / len(y_test_true) * 100, 1)
print(
    f"Total number of mis-classification by customized model: {_mis_count_cus}, which is {_mis_pct_cus}%"
)
Total number of mis-classification by customized model: 168, which is 7.9%
The accuracy achieved by the customized model is 92%, which is really good. For most classes, precision, recall, and f1-score are around 80%–90%. However, some classes didn't perform well, with precision ranging from 0.35 to 0.6. I noticed that classes with a relatively low support amount often produced worse results.
# Per-class precision/recall/F1 for the customized model, printed as text.
class_report_customized = classification_report(
    y_test_true,
    y_pred_customized,
    zero_division=1,
    target_names=labels,
)
print(class_report_customized)
precision recall f1-score support
abraham_grampa_simpson 0.94 0.88 0.91 92
agnes_skinner 1.00 0.60 0.75 5
apu_nahasapeemapetilon 0.90 0.90 0.90 63
barney_gumble 0.90 0.75 0.82 12
bart_simpson 0.88 0.91 0.89 135
carl_carlson 0.58 1.00 0.73 11
charles_montgomery_burns 0.94 0.95 0.95 120
chief_wiggum 0.93 0.94 0.94 100
cletus_spuckler 0.80 0.67 0.73 6
comic_book_guy 0.84 0.88 0.86 48
disco_stu 1.00 0.50 0.67 2
edna_krabappel 0.95 0.79 0.86 47
fat_tony 1.00 1.00 1.00 4
gil 0.75 0.75 0.75 4
groundskeeper_willie 0.76 1.00 0.87 13
homer_simpson 0.92 0.96 0.94 226
kent_brockman 0.96 0.92 0.94 51
krusty_the_clown 0.92 0.98 0.95 122
lenny_leonard 0.97 0.90 0.93 31
lionel_hutz 1.00 1.00 1.00 1
lisa_simpson 0.97 0.86 0.91 136
maggie_simpson 1.00 0.57 0.73 14
marge_simpson 0.98 0.96 0.97 130
martin_prince 0.78 0.88 0.82 8
mayor_quimby 0.85 0.88 0.87 26
milhouse_van_houten 0.98 0.96 0.97 109
miss_hoover 0.75 1.00 0.86 3
moe_szyslak 0.95 0.93 0.94 146
ned_flanders 0.94 0.96 0.95 146
nelson_muntz 0.97 0.84 0.90 37
otto_mann 0.57 1.00 0.73 4
patty_bouvier 0.88 0.88 0.88 8
principal_skinner 0.95 0.94 0.95 120
professor_john_frink 0.35 0.86 0.50 7
rainier_wolfcastle 0.62 1.00 0.77 5
ralph_wiggum 1.00 0.90 0.95 10
selma_bouvier 0.71 0.91 0.80 11
sideshow_bob 0.94 0.96 0.95 89
sideshow_mel 0.60 0.75 0.67 4
snake_jailbird 0.83 0.83 0.83 6
troy_mcclure 1.00 1.00 1.00 2
waylon_smithers 0.87 0.68 0.76 19
accuracy 0.92 2133
macro avg 0.87 0.88 0.86 2133
weighted avg 0.93 0.92 0.92 2133
def plot_support_train_amt(report, train_counts=None):
    """Scatter per-class precision against test-set support, and optionally
    against per-class training-data amount.

    Parameters
    ----------
    report : dict
        Output of sklearn ``classification_report(..., output_dict=True)``.
    train_counts : sequence, optional
        Training-image count per class, in the report's class order. When
        omitted, falls back to the module-level ``len_list_`` if it exists
        (the original code referenced that global directly; it is not
        defined in this file — TODO confirm where it comes from).

    Fixes over the original version:
    - aggregate rows are skipped by name rather than the hard-coded
      position 42, which previously let 'macro avg' and 'weighted avg'
      leak into the scatter;
    - a missing ``len_list_`` no longer raises NameError; the right-hand
      plot is simply skipped.
    """
    if train_counts is None:
        train_counts = globals().get("len_list_")

    aggregate_rows = {"accuracy", "macro avg", "weighted avg"}
    precision_list = []
    support_list = []
    for name, stats in report.items():
        if name in aggregate_rows:
            continue
        precision_list.append(stats["precision"])
        support_list.append(stats["support"])

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    axs[0].scatter(precision_list, support_list)
    axs[0].set_xlabel('Precision')
    axs[0].set_ylabel('Support amount')
    if train_counts is not None:
        axs[1].scatter(precision_list, train_counts)
        axs[1].set_xlabel('Precision')
        axs[1].set_ylabel('Training data set amount')
According to the graphics below, we can find that there are subtle relationships between precision and support amount (training data amount). When the support amount — the test data of a specific class — is over 50, its precision is around 90% to 100%. Similarly, a training amount over 1000 reaches 90% to 100% precision. So I think the amount of data is critical for a class's performance.
# Regenerate the report as a dict so its values can be plotted.
class_report_customized = classification_report(
    y_test_true, y_pred_customized, zero_division=1, output_dict=True
)
plot_support_train_amt(class_report_customized)
from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_confusion_matrix(cf):
    """Draw a heatmap of a (row-normalized) confusion matrix, labelling both
    axes with the character class names."""
    plt.figure(figsize=(20, 15))
    sns.heatmap(cf, annot=False, xticklabels=labels, yticklabels=labels)
    plt.title("Normalized Confusion Matrix", fontsize=23)
    plt.show()
This is the confusion matrix, which indicates how often each ground-truth class is predicted as other classes. Brighter colors represent ground-truth classes whose model predictions are closer to the truth.
# Row-normalized confusion matrix for the customized model.
cf_matrix_customized = confusion_matrix(y_test_true, y_pred_customized, normalize="true")
plot_confusion_matrix(cf_matrix_customized)
def plot_specific_images(a, b):
    """Display one randomly chosen training image for each of two classes,
    side by side, titled with the class name."""
    plt.figure(figsize=(15, 18))
    for idx, cls in enumerate((a, b)):
        folder_path = f"{train_dir}/{cls}"
        candidates = [
            f for f in os.listdir(folder_path)
            if os.path.isfile(os.path.join(folder_path, f))
        ]
        chosen = random.choice(candidates)
        plt.subplot(2, 2, idx + 1)
        plt.imshow(mpimg.imread(f"{folder_path}/{chosen}"))
        plt.axis("off")
        plt.title(cls)
    plt.tight_layout()
    plt.show()
plot_specific_images('agnes_skinner', 'principal_skinner')
Predict on 15 images to see the results
import cv2
def predict_15_imgs(m):
    """Sample random training images, predict them with model `m`, and show
    each image with its true and predicted class.

    NOTE: despite the name, the 4x5 grid displays 20 images, not 15.

    Fix: the original drew class indices with random.randint(1, 41), which
    could never sample class index 0 (the first label); use
    random.randrange(len(labels)) so every class is eligible.
    """
    fig, axes = plt.subplots(
        nrows=4, ncols=5, figsize=(15, 12), subplot_kw={"xticks": [], "yticks": []}
    )
    for i, ax in enumerate(axes.flat):
        class_name = labels[random.randrange(len(labels))]
        folder_path = f"{train_dir}/{class_name}"
        files = os.listdir(folder_path)
        files = [f for f in files if os.path.isfile(os.path.join(folder_path, f))]
        random_file = random.choice(files)
        file_name = folder_path + "/" + random_file
        # Match the training preprocessing: RGB channel order, 180x180, float32.
        img = cv2.imread(file_name)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (180, 180))
        img = img.astype("float32")
        img_tensor = tf.expand_dims(tf.convert_to_tensor(img), axis=0)
        pred = m.predict(img_tensor)
        true_num = np.argmax(pred, axis=1)
        ax.imshow(plt.imread(file_name))
        ax.set_title(f"True: {class_name}\n Predicted: {labels[true_num[0]]}")
    plt.tight_layout()
    plt.show()
predict_15_imgs(best_from_custome_customized_model)
1/1 [==============================] - 10s 10s/step 1/1 [==============================] - 0s 35ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 41ms/step 1/1 [==============================] - 0s 27ms/step 1/1 [==============================] - 0s 27ms/step 1/1 [==============================] - 0s 24ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 25ms/step 1/1 [==============================] - 0s 27ms/step 1/1 [==============================] - 0s 27ms/step 1/1 [==============================] - 0s 28ms/step 1/1 [==============================] - 0s 27ms/step 1/1 [==============================] - 0s 27ms/step 1/1 [==============================] - 0s 27ms/step 1/1 [==============================] - 0s 40ms/step 1/1 [==============================] - 0s 44ms/step 1/1 [==============================] - 0s 44ms/step 1/1 [==============================] - 0s 47ms/step
# Predicted class indices on the test set from the best ResNet50 checkpoint.
y_pred_RN50 = np.argmax(best_from_custome_RN50_model.predict(test_set), axis=1)
22/22 [==============================] - 21s 308ms/step
# Count and report ResNet50's test-set misclassifications.
unequal_indices_RN50 = find_unequal_indexes(y_test_true, y_pred_RN50)
_mis_count_rn50 = len(unequal_indices_RN50)
_mis_pct_rn50 = round(_mis_count_rn50 / len(y_test_true) * 100, 1)
print(
    f"Total number of mis-classification by ResNet50: {_mis_count_rn50}, which is {_mis_pct_rn50}%"
)
Total number of mis-classification by ResNet50: 254, which is 11.9%
The accuracy achieved by the ResNet50 model is 88%, which is also good, like the previous (customized) model. However, looking into the individual indicators (precision, recall and f1-score), we can find a gap between this model and the previous one. Classes with precision lower than 0.8 account for approximately one third. Among these, there are even more with precision below 0.6, some even around 0.4.
# Per-class precision/recall/F1 for ResNet50, printed as text.
class_report_RN50 = classification_report(
    y_test_true,
    y_pred_RN50,
    zero_division=1,
    target_names=labels,
)
print(class_report_RN50)
precision recall f1-score support
abraham_grampa_simpson 0.96 0.93 0.95 92
agnes_skinner 0.50 0.60 0.55 5
apu_nahasapeemapetilon 0.94 0.81 0.87 63
barney_gumble 0.70 0.58 0.64 12
bart_simpson 0.96 0.84 0.90 135
carl_carlson 0.37 0.91 0.53 11
charles_montgomery_burns 0.95 0.86 0.90 120
chief_wiggum 0.92 0.91 0.91 100
cletus_spuckler 0.83 0.83 0.83 6
comic_book_guy 0.77 0.83 0.80 48
disco_stu 1.00 0.50 0.67 2
edna_krabappel 0.91 0.83 0.87 47
fat_tony 0.33 1.00 0.50 4
gil 0.67 0.50 0.57 4
groundskeeper_willie 0.52 1.00 0.68 13
homer_simpson 0.86 0.93 0.89 226
kent_brockman 0.84 0.96 0.90 51
krusty_the_clown 0.91 0.96 0.94 122
lenny_leonard 1.00 0.71 0.83 31
lionel_hutz 0.50 1.00 0.67 1
lisa_simpson 0.92 0.88 0.90 136
maggie_simpson 0.89 0.57 0.70 14
marge_simpson 0.90 0.93 0.91 130
martin_prince 0.58 0.88 0.70 8
mayor_quimby 0.68 0.58 0.62 26
milhouse_van_houten 0.97 0.93 0.95 109
miss_hoover 0.67 0.67 0.67 3
moe_szyslak 0.95 0.86 0.90 146
ned_flanders 0.97 0.96 0.96 146
nelson_muntz 1.00 0.70 0.83 37
otto_mann 0.67 1.00 0.80 4
patty_bouvier 0.78 0.88 0.82 8
principal_skinner 0.92 0.90 0.91 120
professor_john_frink 0.56 0.71 0.63 7
rainier_wolfcastle 0.38 1.00 0.56 5
ralph_wiggum 0.90 0.90 0.90 10
selma_bouvier 0.42 0.91 0.57 11
sideshow_bob 0.96 0.88 0.92 89
sideshow_mel 1.00 1.00 1.00 4
snake_jailbird 0.29 1.00 0.44 6
troy_mcclure 1.00 1.00 1.00 2
waylon_smithers 1.00 0.42 0.59 19
accuracy 0.88 2133
macro avg 0.78 0.83 0.78 2133
weighted avg 0.90 0.88 0.89 2133
According to the graphics below, there are subtle relationships between precision and support amount (training data amount). When the support amount — the test data of a specific class — is over 75, its precision is around 80% to 100%. However, for the training data set, the results are quite scattered. We can see that for a precision of 90%, there are data amounts of 175, 375, 750, and 1000, indicating that in this model, imbalanced data doesn't have a significant impact compared to balanced data.
# Dict form of the ResNet50 report for the precision/support scatter plots.
class_report_RN50 = classification_report(
    y_test_true, y_pred_RN50, zero_division=1, output_dict=True
)
plot_support_train_amt(class_report_RN50)
# Row-normalized confusion matrix for ResNet50.
cf_matrix_RN50 = confusion_matrix(y_test_true, y_pred_RN50, normalize="true")
plot_confusion_matrix(cf_matrix_RN50)
# Two low-precision classes shown side by side for visual comparison.
plot_specific_images('disco_stu', 'otto_mann')
# Predicted class indices on the test set from the best MobileNetV2 checkpoint.
y_pred_MNV2 = np.argmax(best_from_custome_MNV2_model.predict(test_set), axis=1)
22/22 [==============================] - 5s 178ms/step
# Count and report MobileNetV2's test-set misclassifications.
unequal_indices_MNV2 = find_unequal_indexes(y_test_true, y_pred_MNV2)
# Fixed copy-paste bug: the message previously claimed "ResNet50" while
# actually reporting the MobileNetV2 model's errors.
print(
    f"Total number of mis-classification by Mobile Net V2: {len(unequal_indices_MNV2)}, which is {round(len(unequal_indices_MNV2)/len(y_test_true)*100, 1)}%"
)
Total number of mis-classification by ResNet50: 722, which is 33.8%
This is the worst-performing model, with only 66% accuracy. This is also the reason I trained three models: I wanted to find a pre-trained model that could compare favorably to my customized model.
# Per-class precision/recall/F1 for MobileNetV2, printed as text.
class_report_MNV2 = classification_report(
    y_test_true,
    y_pred_MNV2,
    zero_division=1,
    target_names=labels,
)
print(class_report_MNV2)
precision recall f1-score support
abraham_grampa_simpson 0.70 0.70 0.70 92
agnes_skinner 0.33 0.40 0.36 5
apu_nahasapeemapetilon 0.86 0.78 0.82 63
barney_gumble 0.33 0.50 0.40 12
bart_simpson 0.68 0.81 0.74 135
carl_carlson 0.42 0.91 0.57 11
charles_montgomery_burns 0.78 0.67 0.72 120
chief_wiggum 0.85 0.72 0.78 100
cletus_spuckler 0.09 0.50 0.15 6
comic_book_guy 0.78 0.81 0.80 48
disco_stu 1.00 0.50 0.67 2
edna_krabappel 0.96 0.53 0.68 47
fat_tony 0.50 0.50 0.50 4
gil 0.20 0.25 0.22 4
groundskeeper_willie 0.33 0.62 0.43 13
homer_simpson 0.57 0.78 0.66 226
kent_brockman 0.75 0.80 0.77 51
krusty_the_clown 0.92 0.65 0.76 122
lenny_leonard 0.56 0.32 0.41 31
lionel_hutz 1.00 0.00 0.00 1
lisa_simpson 0.81 0.65 0.72 136
maggie_simpson 0.17 0.07 0.10 14
marge_simpson 0.83 0.73 0.78 130
martin_prince 0.10 0.12 0.11 8
mayor_quimby 0.75 0.12 0.20 26
milhouse_van_houten 0.88 0.56 0.69 109
miss_hoover 1.00 0.33 0.50 3
moe_szyslak 0.60 0.68 0.64 146
ned_flanders 0.55 0.67 0.60 146
nelson_muntz 0.30 0.46 0.37 37
otto_mann 1.00 0.50 0.67 4
patty_bouvier 0.33 0.62 0.43 8
principal_skinner 0.85 0.55 0.67 120
professor_john_frink 0.44 0.57 0.50 7
rainier_wolfcastle 0.67 0.40 0.50 5
ralph_wiggum 0.33 0.70 0.45 10
selma_bouvier 0.22 0.36 0.28 11
sideshow_bob 0.80 0.79 0.80 89
sideshow_mel 0.12 0.50 0.20 4
snake_jailbird 0.33 0.33 0.33 6
troy_mcclure 0.67 1.00 0.80 2
waylon_smithers 1.00 0.11 0.19 19
accuracy 0.66 2133
macro avg 0.60 0.54 0.52 2133
weighted avg 0.71 0.66 0.67 2133
Same situation with previous model (ResNet50), the support amount more than 50 have relatively high precision from 70% to 90%, then for the training data set, the results are quite scattered. Indicating that imbalanced data doesn't have a significant impact compared to balanced data.
# Dict form of the MobileNetV2 report for the precision/support plots.
class_report_MNV2 = classification_report(
    y_test_true, y_pred_MNV2, zero_division=1, output_dict=True
)
plot_support_train_amt(class_report_MNV2)
This is the funniest part: this model predicted every "Lionel Hutz" case as "Principal Skinner", meaning it has 100% misclassification on "Lionel Hutz".
# Row-normalized confusion matrix for MobileNetV2.
cf_matrix_MNV2 = confusion_matrix(y_test_true, y_pred_MNV2, normalize="true")
plot_confusion_matrix(cf_matrix_MNV2)
# The class pair that MobileNetV2 confuses in 100% of cases.
plot_specific_images('lionel_hutz', 'principal_skinner')
These numbers (training parameters, training time, losses, accuracies) are copied from the previous execution results.
import pandas as pd
# Hand-copied summary of the three training runs (customized CNN, MobileNetV2,
# ResNet50 — rows 0, 1, 2): parameter counts, wall-clock training time, and
# best loss/accuracy pairs from earlier executions.
# NOTE(review): the variable name "comparision" is a typo for "comparison".
comparision = {
    "Training Parameters": ['105,705,034', '339,114', '25,728,938'],
    "Model training time": ['91m 53s', '40m 34s', '34m 14s'],
    "Lowest loss": ['0.2658', '1.2937', '0.3887'],
    "with accuracy": ['93%', '63%', '89%'],
    "Highest accuracy": ['93%', '64%', '89%'],
    "with loss": ['0.2981', '1.3655', '0.3887'],
}
com_df = pd.DataFrame(comparision)
com_df
| Training Parameters | Model training time | Lowest loss | with accuracy | Highest accuracy | with loss | |
|---|---|---|---|---|---|---|
| 0 | 105,705,034 | 91m 53s | 0.2658 | 93% | 93% | 0.2981 |
| 1 | 339,114 | 40m 34s | 1.2937 | 63% | 64% | 1.3655 |
| 2 | 25,728,938 | 34m 14s | 0.3887 | 89% | 89% | 0.3887 |
import matplotlib.pyplot as plt
import numpy as np
# Display strings for the three models (customized CNN, MobileNetV2, ResNet50).
comparison = {
    "Training Parameters": ['105,705,034', '339,114', '25,728,938'],
    "Model training time": ['91', '40', '34'],
    "Highest validation accuracy": ['93%', '64%', '89%'],
    "with loss": ['0.2981', '1.3655', '0.3887'],
}

# Parse the display strings into numbers for plotting.
# Fixed: the original indexed the nonexistent keys "Lowest loss" and
# "with accuracy" (left over from the earlier, differently-keyed dict),
# which raised KeyError. Use the keys this dict actually defines.
parameters = [int(param.replace(',', '')) for param in comparison["Training Parameters"]]
training_time = [int(time.split(' ')[0]) for time in comparison["Model training time"]]
lowest_loss = [float(loss) for loss in comparison["with loss"]]
accuracy = [int(acc.replace('%', '')) for acc in comparison["Highest validation accuracy"]]
According to the graphics, we can see that these three models are quite different from each other based on their attributes. They have varying execution times, parameters, and performance
# 3D scatter of parameter count vs. training time vs. loss, colored by accuracy.
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')
# Fixed: the original called ax.scatter twice — once for the plot and once
# inside plt.colorbar — drawing every point twice; reuse a single handle.
sc = ax.scatter(parameters, training_time, lowest_loss, c=accuracy, cmap='viridis', s=100)
ax.set_xlabel('Training Parameters')
ax.set_ylabel('Model training time (minutes)')
ax.set_zlabel('Loss')
cbar = plt.colorbar(sc)
cbar.set_label('Accuracy (%)')
plt.title('4D Comparison')
plt.show()
I trained the model with GPU, however they took me several hours to train (I didn't just train each one once, maybe 3-4 times for each model), so based on the graphics I would definitely choose the shortest training time but with relatively better performance, which is the pre-trained model ResNet 50.
# Filled-contour view of training time vs. loss, colored by accuracy.
plt.figure(figsize=(8, 6))
filled = plt.tricontourf(training_time, lowest_loss, accuracy, cmap='viridis')
plt.colorbar(filled, label='Accuracy (%)')
plt.xlabel('Model training time (minutes)')
plt.ylabel('Loss')
plt.title('2D Contour Plot of 3D Data')
plt.grid(True)
plt.show()